from IO import CageViewer
from IO import EthoVision
from IO import EssayPart
import IO
import numpy as np
import logging
import pandas as pd
import yaml
import copy
import os.path
import glob

def buildDic(oldD, newD, columns):
    """
    Invert a ``{key: [column, ...]}`` mapping into ``oldD`` as ``{column: key}``.

    Only columns that are contained in ``columns`` are added; every other
    column is reported with a warning and skipped.

    oldD    -- dictionary that is updated in place (and also returned)
    newD    -- mapping from a key to an iterable of column names
    columns -- container of the column names that actually exist
    """
    for k, v in newD.items():
        for column in v:
            if column in columns:
                oldD[column] = k
            else:
                # Use %s: the column is a name, and %d made the log call itself
                # fail with "%d format: a number is required".
                logging.warning('Column %s not present in data. Will skip it!', column)
    return oldD

def addByIndex(df, group, column):
    """
    Copy ``group[column]`` into row 0 of ``df``, one cell per index label.

    For every label in ``group.index`` the value is written to the column
    named ``<label>_<column>`` of ``df`` (modified in place).
    """
    for label in group.index:
        target = label + '_' + column
        value = group[column][label]
        df.loc[0, target] = value

def transposeByColumn(data, prefix_column, target_column):
    """
    Flatten the rows of ``data`` into a single wide row.

    For every row, the values of the columns listed in ``target_column`` are
    taken, cast to float, and renamed to ``<prefix>_<column>`` where
    ``<prefix>`` is that row's value in ``prefix_column``. All pieces are
    concatenated along axis=1.

    data          -- DataFrame to flatten (it is NOT modified)
    prefix_column -- name of the column that provides the prefix of each row
    target_column -- list of column names to carry over

    Returns a DataFrame with a single row (index 0), or an empty DataFrame if
    ``data`` has no rows.
    """
    # The leading empty frame keeps pd.concat from failing on empty input.
    pieces = [pd.DataFrame(None)]
    for _, row in data.iterrows():
        prefix = row[prefix_column]
        piece = pd.DataFrame(data=row[target_column]).T
        # All pieces must share one index value so concat aligns them into a
        # single row. Normalise it on the piece instead of overwriting the
        # caller's data.index.
        piece.index = [0]
        piece.columns = [prefix + '_' + col for col in target_column]
        pieces.append(piece.astype('float'))

    return pd.concat(pieces, axis=1)


def trialCollapse(group, periodColumns, trial_agg):
    """
    Collapse all period rows of one trial into a single row (used in groupby).

    The extended trial type is derived from the number of rows in ``group``
    and from the TrialType of its first row, and is added as the first column
    'ExtendedTrialType'. The per-period values of ``periodColumns`` are
    appended as ``<PeriodType>_<column>`` columns.

    group         -- rows of a single trial
    periodColumns -- columns that are spread out per PeriodType
    trial_agg     -- aggregation passed to ``agg`` for the trial level columns
    """
    nRows = group.shape[0]
    isGo = group.iloc[0]['TrialType'] == 'Go'

    if nRows == 2:
        eType = 'Premature'
    elif nRows == 3:
        eType = 'Go_Omission' if isGo else 'FalseAlarm'
    elif nRows == 4:
        if isGo:
            eType = 'Go_Correct'
        else:
            # Only the Reward period decides between the two NoGo outcomes
            rewardLicks = group.query('PeriodType == "Reward"')['Lick_OnSet']
            if rewardLicks.any() > 0:
                eType = 'NG_Correct'
            else:
                eType = 'NG_Omission'
    else:
        eType = '???'

    perPeriod = transposeByColumn(group, 'PeriodType', periodColumns)
    collapsed = group.groupby(['MouseID', 'TrialCnt']).agg(trial_agg)
    collapsed = collapsed.set_index('TrialType')
    collapsed.insert(0, 'ExtendedTrialType', eType)
    # Align both parts on the same index before gluing them together
    perPeriod.index = collapsed.index
    return pd.concat([collapsed, perPeriod], axis=1)

def _firstValueOf(data, extendedTrialType, colname):
    """Return the first ``colname`` value of the rows with the given ExtendedTrialType, or 0 if there is none."""
    values = data.loc[data['ExtendedTrialType'] == extendedTrialType, colname].values
    return values[0] if len(values) else 0


def _safeRate(responses, total):
    """Return ``responses / total``, or 0 when ``total`` is zero."""
    # An explicit check is required: numpy scalars do not raise
    # ZeroDivisionError, they return nan/inf and emit a RuntimeWarning.
    if total == 0:
        return 0
    return responses / total


def calculateResponseRate(data, colname):
    """
    Calculate the Go and NoGo response rates.

    ``data`` is expected to hold one row per ExtendedTrialType; the value in
    ``colname`` of that row is used as its count (a missing type counts as 0).

        Go_ResponseRate = Go_Correct / (Go_Correct + Go_Omission)
        NG_ResponseRate = FalseAlarm / (FalseAlarm + NG_Omission + NG_Correct)

    A rate whose denominator is zero is reported as 0. If the
    'ExtendedTrialType' column is missing, both rates are NaN.

    Returns a one-row DataFrame with the columns 'Go_ResponseRate' and
    'NG_ResponseRate'.
    """
    if 'ExtendedTrialType' not in data.columns:
        logging.warning('Cannot calculate ResponseRate because the extendedTrialType column is missing!')
        Go_RR = NG_RR = np.nan
    else:
        goCorrect = _firstValueOf(data, 'Go_Correct', colname)
        goOmission = _firstValueOf(data, 'Go_Omission', colname)
        Go_RR = _safeRate(goCorrect, goCorrect + goOmission)

        falseAlarm = _firstValueOf(data, 'FalseAlarm', colname)
        ngOmission = _firstValueOf(data, 'NG_Omission', colname)
        ngCorrect = _firstValueOf(data, 'NG_Correct', colname)
        NG_RR = _safeRate(falseAlarm, falseAlarm + ngOmission + ngCorrect)

    T = pd.DataFrame([Go_RR, NG_RR]).T
    T.columns = ['Go_ResponseRate', 'NG_ResponseRate']
    return T



def importConfigFile(yamlFile):
    """
    Load the analysis configuration from a YAML file.

    The file is expected to map each section name to a list of
    ``{column: aggregation}`` mappings. The mappings of a section are merged
    into one dictionary, and the aggregation name 'mostcommon' is replaced by
    ``EssayPart.EssayPart.lambda_most_common``.

    yamlFile -- path of the YAML configuration file

    Returns ``{section: {column: aggregation}}``. If the file cannot be read
    or parsed, the error is logged and an empty dictionary is returned.
    """
    cfg = {}
    try:
        with open(yamlFile) as f:
            # safe_load instead of yaml.load: a bare yaml.load(stream) raises a
            # TypeError on PyYAML >= 6 (Loader is mandatory) and could construct
            # arbitrary Python objects on older versions.
            # NOTE(review): a config that relies on python-specific YAML tags
            # would no longer load -- confirm none does.
            cfg = yaml.safe_load(f) or {}  # An empty file parses to None
    except Exception as e:
        logging.error('Loading config file failed! %s', e)

    ncfg = {}
    for section, section_cont in cfg.items():
        how = {}
        for column in section_cont or []:  # A section without entries parses to None
            for k, v in column.items():
                how[k] = EssayPart.EssayPart.lambda_most_common if v == 'mostcommon' else v
        ncfg[section] = how

    return ncfg

def calculateExtendedTable(table, groupbySelection, aggOps):
    """
    Aggregate the trial table and attach the response rates per mouse.

    table            -- trial table; it is copied, the original stays untouched
    groupbySelection -- list of columns to group by
    aggOps           -- aggregation passed to ``agg`` for the grouped table

    Returns the tuple (extended trial table, response rate table).
    """
    flat = table.copy()
    flat.reset_index(inplace=True)

    def rateOfMouse(mouseRows):
        return calculateResponseRate(mouseRows, 'Duration')

    # Response rates are based on the number of 'Duration' entries per group
    counts = flat[groupbySelection + ['Duration']].groupby(groupbySelection).count()
    counts = counts.reset_index()
    responserateTable = counts.groupby('MouseID').apply(rateOfMouse)
    # Remove artificial index introduced by apply
    responserateTable.index = responserateTable.index.levels[0]

    # Stage4
    aggregated = flat.groupby(groupbySelection).agg(aggOps)

    return aggregated.join(responserateTable), responserateTable

def TrialBatchAnalysis(masterDirectory, yamlFile='TrialAnalysis.yaml', **kwargs):
    """
    Run TrialAnalysis on every session directory inside ``masterDirectory``.

    masterDirectory -- directory that contains one sub-directory per session
    yamlFile        -- configuration file handed to TrialAnalysis
    kwargs          -- further keyword arguments handed to TrialAnalysis

    A failing session is logged and does not stop the remaining sessions.
    """
    # glob('*') also returns regular files (notes, exports, ...). Only
    # directories can be sessions; sorting makes the processing order reproducible.
    sessionDirs = sorted(p for p in glob.glob(os.path.join(masterDirectory, '*')) if os.path.isdir(p))
    for sessionDir in sessionDirs:
        try:
            logging.info('Analysing session %s ...', sessionDir)
            TrialAnalysis(sessionDir, yamlFile, **kwargs)
            logging.info('Finished analysis of %s', sessionDir)
        except Exception as e:
            logging.warning('Session Analysis failed! %s', e)
    logging.info('Finished all sessions in %s', masterDirectory)

def _exportTables(pT, tT, etT, outDir, exportFormat):
    """Write the period table, the trial table and the three extended trial tables to outDir."""
    if exportFormat == 'csv':
        logging.info('Writing Tables as CSV to %s ...', outDir)
        pT.to_csv(os.path.join(outDir, 'PeriodTable.csv'))
        tT.to_csv(os.path.join(outDir, 'TrialTable.csv'))
        etT[0].to_csv(os.path.join(outDir, 'ExtTrialTableByExtendedTrialType.csv'))
        etT[1].to_csv(os.path.join(outDir, 'ExtTrialTableByTrialType.csv'))
        etT[2].to_csv(os.path.join(outDir, 'ExtTrialTableByMouse.csv'))
    elif exportFormat == 'xlsx':
        try:
            logging.info('Writing Tables as XLSX to %s ...', outDir)
            pT.to_excel(os.path.join(outDir, 'PeriodTable.xlsx'))
            tT.to_excel(os.path.join(outDir, 'TrialTable.xlsx'))
            # The context manager saves and closes the workbook;
            # ExcelWriter.save() no longer exists in pandas 2.
            with pd.ExcelWriter(os.path.join(outDir, 'ExtendedTrialTable.xlsx')) as writer:
                etT[0].to_excel(writer, sheet_name='ByExtendedTrialType')
                etT[1].to_excel(writer, sheet_name='ByTrialType')
                etT[2].to_excel(writer, sheet_name='ByMouse')
        except Exception as e:
            logging.error('Writing output files failed! %s', e)
    else:
        logging.warning('Unknown export format!')


def TrialAnalysis(sessionDirectory, yamlFile='TrialAnalysis.yaml', exportFormat='csv', exportFullTableAsCSV=False):
    """
    Analyse one session and store the results in ``<sessionDirectory>/analysis``.

    sessionDirectory     -- directory of the session, loaded with IO.loadSession
    yamlFile             -- configuration file handed to TableAnalysis
    exportFormat         -- 'csv' or 'xlsx'; any other value only logs a warning
    exportFullTableAsCSV -- additionally write the full table as FullTable.csv

    While the analysis runs, log messages are also written to 'Log.txt' in the
    output directory. Errors are logged instead of raised. If the output
    directory cannot be created because of missing permissions, the function
    returns without doing anything.
    """
    outDir = os.path.join(sessionDirectory, 'analysis')
    logging.info('Loading Session data from %s', sessionDirectory)

    logger = logging.getLogger()
    # Bound before the try block so that the cleanup below never hits an
    # unbound name, whatever fails first.
    fh = None
    try:
        try:
            try:
                os.mkdir(outDir)
            except FileExistsError:
                logging.warning('Analysis folder already present! Overwriting all previous files!')
            except PermissionError:
                logging.error('Cannot create output directory! Check permission!')
                return

            try:
                fh = logging.FileHandler(os.path.join(outDir, 'Log.txt'), 'w')
                logger.addHandler(fh)
            except Exception as e:
                logging.error('Setting up log file failed! %s', e)
                fh = None

            T = IO.loadSession(sessionDirectory)
            pT, tT, etT, T2 = TableAnalysis(T, yamlFile)

            sP = os.path.join(outDir, 'FullTable.hdf')
            logging.info('Storing Session Table to %s ...', sP)
            T2.to_hdf(sP, key='T')
            if exportFullTableAsCSV:
                T2.to_csv(os.path.join(outDir, 'FullTable.csv'))

            _exportTables(pT, tT, etT, outDir, exportFormat)
        except Exception as e:
            logging.error('Failed to perform Analysis! %s', e)
        logging.info('Done processing session!')
    finally:
        if fh is not None:
            logger.removeHandler(fh)
            fh.close()  # Release the file handle of Log.txt

def TableAnalysis(Table, yamlFile):
    """
    Run the staged aggregation of a session table.

    Table    -- session table; it is copied, the original stays untouched
    yamlFile -- configuration file, read with importConfigFile. The sections
                'TrialAggregation', 'PeriodAggregation',
                'ExtendedTrialAggregation' and 'Multiply' are used.

    Returns the tuple
        (periodTable,
         trialTable,
         [extTrialTable by ExtendedTrialType, by TrialType, by Mouse],
         T2)
    where T2 is the copied input table with the ExtendedTrialType merged in.
    """
    Table = Table.copy()
    #Rename the PeriodTypes
    # Assign the result back instead of a chained inplace replace, which has no
    # effect on Table when pandas Copy-on-Write is active.
    Table['PeriodType'] = Table['PeriodType'].replace(
        ['.*Precue', '.*Cue', '.*Reward', '.*ITI'], ['Precue', 'Cue', 'Reward', 'ITI'], regex=True)

    cfg = importConfigFile(yamlFile)
    trial_agg = cfg['TrialAggregation']
    period_agg = cfg['PeriodAggregation']
    eTT_agg = cfg['ExtendedTrialAggregation']
    multiplyColumns = cfg['Multiply']

    # Aggregation of the collapsed trials: one entry per <PeriodType>_<column>.
    # A period column without its own function falls back to the trial one.
    extperiod_agg = copy.copy(trial_agg)
    for cn, fun in period_agg.items():
        for pT in Table['PeriodType'].unique():
            if fun is None:
                extperiod_agg[pT + '_' + cn] = trial_agg[cn]
            else:
                extperiod_agg[pT + '_' + cn] = fun
        extperiod_agg[cn] = trial_agg[cn]

    #Some fields will be handled differently in calculateExtendedTable
    extperiod_agg.update(eTT_agg)

    #Generate a template dataframe
    colNames = list(period_agg.keys())

    #If there are some columns not present, add them artificially!
    referencedColumns = np.unique(list(trial_agg.keys()) + list(period_agg.keys()) + list(eTT_agg.keys()))
    for col in pd.Index(referencedColumns).difference(Table.columns):
        logging.warning('Adding column %s with NaN', col)
        Table.insert(len(Table.columns), col, np.nan)

    #Make it possible to sum up mobility in time
    for k, v in multiplyColumns.items():
        Table[k] = Table[k] * Table[v]

    #Stage 1
    periodTable = Table.groupby(['MouseID', 'TrialCnt', 'PeriodType']).agg(trial_agg)

    #Stage 2
    S2 = periodTable.reset_index()
    a = lambda x: trialCollapse(x, colNames, trial_agg)
    trialTable = S2.groupby(['MouseID', 'TrialCnt']).apply(a)
    trialTable['Precue_ResponseRate'] = trialTable['Precue_NosePoke_OnSet'] / trialTable['Precue_Duration']
    trialTable['Precue_ResponseRate'] = trialTable['Precue_ResponseRate'].astype(float)

    #Add a counting number of Trials
    trialTable['nTrials'] = 0
    extperiod_agg['nTrials'] = 'count'
    extTrialTable, rrT = calculateExtendedTable(trialTable, ['MouseID', 'ExtendedTrialType'], extperiod_agg)
    extTrialTable2, _ = calculateExtendedTable(trialTable, ['MouseID', 'TrialType'], extperiod_agg)
    extTrialTable3, _ = calculateExtendedTable(trialTable, ['MouseID'], extperiod_agg)

    #Copy the responseRate to all tables
    # rrT is indexed by MouseID only, so the by-TrialType table is temporarily
    # re-indexed by MouseID for the update.
    extTrialTable2 = extTrialTable2.drop(['TrialType'], axis=1).reset_index().set_index('MouseID')
    extTrialTable2.update(rrT)
    # Restore the (MouseID, TrialType) index; the result of this call used to
    # be discarded, which left the table indexed by MouseID alone.
    extTrialTable2 = extTrialTable2.reset_index().set_index(['MouseID', 'TrialType'])
    extTrialTable3.update(rrT)

    #Add the ExtendedTrial column to the original Table
    T2 = pd.merge(Table.reset_index(), trialTable.ExtendedTrialType.reset_index().drop('TrialType', axis=1), on=['MouseID', 'TrialCnt'], how='inner').set_index('Time')

    return periodTable, trialTable, [extTrialTable, extTrialTable2, extTrialTable3], T2

